{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "SocialAttentionDQN",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "language": "python",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "source": [],
        "metadata": {
          "collapsed": false
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sepDWoBqdRMK"
      },
      "source": [
        "# Training a DQN with social attention on `intersection-v0`\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Kx8X4s8krNWt"
      },
      "source": [
        "#@title Import requirements\n",
        "\n",
        "# Environment\n",
        "!pip install highway-env\n",
        "import gymnasium as gym\n",
        "\n",
        "# Agent\n",
        "!pip install git+https://github.com/eleurent/rl-agents#egg=rl-agents\n",
        "\n",
        "# Visualisation utils\n",
        "!pip install moviepy\n",
        "!pip install imageio_ffmpeg\n",
        "import sys\n",
        "%load_ext tensorboard\n",
        "!pip install tensorboardx gym pyvirtualdisplay\n",
        "!apt-get install -y xvfb python-opengl ffmpeg\n",
        "!git clone https://github.com/eleurent/highway-env.git 2> /dev/null\n",
        "sys.path.insert(0, '/content/highway-env/scripts/')\n",
        "from utils import show_videos"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vvOEW00pdHrG"
      },
      "source": [
        "## Training\n",
        "\n",
        "We use a policy architecture based on social attention, see [[Leurent and Mercat, 2019]](https://arxiv.org/abs/1911.12250).\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QowKW3ix45ZW"
      },
      "source": [
        "#@title Prepare environment, agent, and evaluation process.\n",
        "\n",
        "NUM_EPISODES = 3000  #@param {type: \"integer\"}\n",
        "\n",
        "from rl_agents.trainer.evaluation import Evaluation\n",
        "from rl_agents.agents.common.factory import load_agent, load_environment\n",
        "\n",
        "# Get the environment and agent configurations from the rl-agents repository\n",
        "!git clone https://github.com/eleurent/rl-agents.git 2> /dev/null\n",
        "%cd /content/rl-agents/scripts/\n",
        "env_config = 'configs/IntersectionEnv/env.json'\n",
        "agent_config = 'configs/IntersectionEnv/agents/DQNAgent/ego_attention_2h.json'\n",
        "\n",
        "env = load_environment(env_config)\n",
        "agent = load_agent(agent_config, env)\n",
        "evaluation = Evaluation(env, agent, num_episodes=NUM_EPISODES, display_env=False, display_agent=False)\n",
        "print(f\"Ready to train {agent} on {env}\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nqnGqW6jd1xN"
      },
      "source": [
        "Run tensorboard locally to visualize training."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q7QJY2wc4_1N"
      },
      "source": [
        "%tensorboard --logdir \"{evaluation.directory}\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BtK9dtfb0JMF"
      },
      "source": [
        "Start training. This should take about an hour."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sFVq1gFz42Eg"
      },
      "source": [
        "evaluation.train()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-lNvWg42RWiw"
      },
      "source": [
        "Progress can be visualised in the tensorboard cell above, which should update every 30s (or manually). You may need to click the *Fit domain to data* buttons below each graph."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VKfvu5uhzCIU"
      },
      "source": [
        "## Testing"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gY0rpVYUtRpN"
      },
      "source": [
        "#@title Run the learned policy for a few episodes.\n",
        "env = load_environment(env_config)\n",
        "env.config[\"offscreen_rendering\"] = True\n",
        "agent = load_agent(agent_config, env)\n",
        "evaluation = Evaluation(env, agent, num_episodes=1)\n",
        "evaluation.train()\n",
        "show_videos(evaluation.run_directory)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}